-
Notifications
You must be signed in to change notification settings - Fork 12.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[SCF][PIPELINE] Handle the case when values from the peeled prologue may escape out of the loop #105755
[SCF][PIPELINE] Handle the case when values from the peeled prologue may escape out of the loop #105755
Conversation
@llvm/pr-subscribers-mlir-scf @llvm/pr-subscribers-mlir Author: None (pawelszczerbuk) ChangesPreviously the values in the peeled prologue that weren't treated with the Full diff: https://github.com/llvm/llvm-project/pull/105755.diff 2 Files Affected:
diff --git a/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp b/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp
index cc1a22d0d48a18..d8e1cc0ecef88e 100644
--- a/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp
@@ -268,7 +268,7 @@ cloneAndUpdateOperands(RewriterBase &rewriter, Operation *op,
}
void LoopPipelinerInternal::emitPrologue(RewriterBase &rewriter) {
- // Initialize the iteration argument to the loop initiale values.
+ // Initialize the iteration argument to the loop initial values.
for (auto [arg, operand] :
llvm::zip(forOp.getRegionIterArgs(), forOp.getInitsMutable())) {
setValueMapping(arg, operand.get(), 0);
@@ -320,16 +320,26 @@ void LoopPipelinerInternal::emitPrologue(RewriterBase &rewriter) {
if (annotateFn)
annotateFn(newOp, PipeliningOption::PipelinerPart::Prologue, i);
for (unsigned destId : llvm::seq(unsigned(0), op->getNumResults())) {
- setValueMapping(op->getResult(destId), newOp->getResult(destId),
- i - stages[op]);
+ Value source = newOp->getResult(destId);
// If the value is a loop carried dependency update the loop argument
- // mapping.
for (OpOperand &operand : yield->getOpOperands()) {
if (operand.get() != op->getResult(destId))
continue;
+ if (predicates[predicateIdx] &&
+ !forOp.getResult(operand.getOperandNumber()).use_empty()) {
+ // If the value is used outside the loop, we need to make sure we
+ // return the correct version of it.
+ Value prevValue = valueMapping
+ [forOp.getRegionIterArgs()[operand.getOperandNumber()]]
+ [i - stages[op]];
+ source = rewriter.create<arith::SelectOp>(
+ loc, predicates[predicateIdx], source, prevValue);
+ }
setValueMapping(forOp.getRegionIterArgs()[operand.getOperandNumber()],
- newOp->getResult(destId), i - stages[op] + 1);
+ source, i - stages[op] + 1);
}
+ setValueMapping(op->getResult(destId), newOp->getResult(destId),
+ i - stages[op]);
}
}
}
diff --git a/mlir/test/Dialect/SCF/loop-pipelining.mlir b/mlir/test/Dialect/SCF/loop-pipelining.mlir
index 46e7feca4329ee..9687f80f5ddfc8 100644
--- a/mlir/test/Dialect/SCF/loop-pipelining.mlir
+++ b/mlir/test/Dialect/SCF/loop-pipelining.mlir
@@ -703,18 +703,26 @@ func.func @distance_1_use(%A: memref<?xf32>, %result: memref<?xf32>) {
// -----
// NOEPILOGUE-LABEL: stage_0_value_escape(
-func.func @stage_0_value_escape(%A: memref<?xf32>, %result: memref<?xf32>) {
+func.func @stage_0_value_escape(%A: memref<?xf32>, %result: memref<?xf32>, %ub: index) {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
- %c4 = arith.constant 4 : index
%cf = arith.constant 1.0 : f32
-// NOEPILOGUE: %[[C3:.+]] = arith.constant 3 : index
-// NOEPILOGUE: %[[A:.+]] = arith.addf
-// NOEPILOGUE: scf.for %[[IV:.+]] = {{.*}} iter_args(%[[ARG:.+]] = %[[A]],
-// NOEPILOGUE: %[[C:.+]] = arith.cmpi slt, %[[IV]], %[[C3]] : index
-// NOEPILOGUE: %[[S:.+]] = arith.select %[[C]], %{{.+}}, %[[ARG]] : f32
-// NOEPILOGUE: scf.yield %[[S]]
- %r = scf.for %i0 = %c0 to %c4 step %c1 iter_args(%arg0 = %cf) -> (f32) {
+// NOEPILOGUE: %[[UB:[^,]+]]: index)
+// NOEPILOGUE-DAG: %[[C0:.+]] = arith.constant 0 : index
+// NOEPILOGUE-DAG: %[[C1:.+]] = arith.constant 1 : index
+// NOEPILOGUE-DAG: %[[CF:.+]] = arith.constant 1.000000e+00
+// NOEPILOGUE: %[[CND0:.+]] = arith.cmpi sgt, %[[UB]], %[[C0]]
+// NOEPILOGUE: scf.if
+// NOEPILOGUE: %[[IF:.+]] = scf.if %[[CND0]]
+// NOEPILOGUE: %[[A:.+]] = arith.addf
+// NOEPILOGUE: scf.yield %[[A]]
+// NOEPILOGUE: %[[S0:.+]] = arith.select %[[CND0]], %[[IF]], %[[CF]]
+// NOEPILOGUE: scf.for %[[IV:.+]] = {{.*}} iter_args(%[[ARG:.+]] = %[[S0]],
+// NOEPILOGUE: %[[UB_1:.+]] = arith.subi %[[UB]], %[[C1]] : index
+// NOEPILOGUE: %[[CND1:.+]] = arith.cmpi slt, %[[IV]], %[[UB_1]] : index
+// NOEPILOGUE: %[[S1:.+]] = arith.select %[[CND1]], %{{.+}}, %[[ARG]] : f32
+// NOEPILOGUE: scf.yield %[[S1]]
+ %r = scf.for %i0 = %c0 to %ub step %c1 iter_args(%arg0 = %cf) -> (f32) {
%A_elem = memref.load %A[%i0] { __test_pipelining_stage__ = 0, __test_pipelining_op_order__ = 1 } : memref<?xf32>
%A1_elem = arith.addf %A_elem, %arg0 { __test_pipelining_stage__ = 1, __test_pipelining_op_order__ = 0 } : f32
memref.store %A1_elem, %result[%c0] { __test_pipelining_stage__ = 2, __test_pipelining_op_order__ = 2 } : memref<?xf32>
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Great catch, thanks!
…may escape out of the loop (llvm#105755) Previously the values in the peeled prologue that weren't treated with the `predicateFn` were passed to the loop body without any other predication. If those values are later used outside of the loop body, they may be incorrect if the num iterations is smaller than num stages - 1. We need similar masking for those, as is done in the main loop body, using already existing predicates.
…may escape out of the loop (#105755) Previously the values in the peeled prologue that weren't treated with the `predicateFn` were passed to the loop body without any other predication. If those values are later used outside of the loop body, they may be incorrect if the num iterations is smaller than num stages - 1. We need similar masking for those, as is done in the main loop body, using already existing predicates.
…may escape out of the loop (llvm#105755) Previously the values in the peeled prologue that weren't treated with the `predicateFn` were passed to the loop body without any other predication. If those values are later used outside of the loop body, they may be incorrect if the num iterations is smaller than num stages - 1. We need similar masking for those, as is done in the main loop body, using already existing predicates.
Previously the values in the peeled prologue that weren't treated with the
predicateFn
were passed to the loop body without any other predication. If those values are later used outside of the loop body, they may be incorrect if the num iterations is smaller than num stages - 1. We need similar masking for those, as is done in the main loop body, using already existing predicates.